import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
import re
import pandas as pd
import IPython
import numpy as np
# Matches the age suffix (e.g. ``AGE_42``) embedded in population column names.
_AGE_PATTERN = re.compile(r'AGE_(\d+)')


def _column_age(column):
    """Return the integer age encoded as ``AGE_<n>`` in a population column name."""
    match = _AGE_PATTERN.search(column)
    if match is None:
        raise ValueError(f"Cannot extract an age from column name: {column!r}")
    return int(match.group(1))


def animate_population_pyramid(mmodel, interval=300, save_path=None, auto_show=True,
                               age_limits=(20, 65), fps=5):
    """
    Build an animated population pyramid colored by three age groups.

    One frame is drawn per row of ``mmodel.basedf``. Male values are plotted to
    the left of a vertical divider at x = 0 and female values to the right.
    The legend shows each age group's share of the total population for the
    frame being displayed, and 'Male'/'Female' labels are placed above the chart.

    Parameters
    ----------
    mmodel : ModelFlow model instance
        Must contain population variables 'pop__male*__AGE_*' and
        'pop__female*__AGE_*'; the values are read from ``mmodel.basedf``.
    interval : int
        Delay between frames (ms) for the interactive animation.
    save_path : str or path-like, optional
        If provided, saves the animation. Only '.gif' (pillow writer) and
        '.mp4' (ffmpeg writer) are supported; any other extension triggers a
        warning and nothing is written.
    auto_show : bool
        If True and running under IPython, displays the animation inline.
    age_limits : sequence of two ints, optional
        The two cutoff ages between groups.
        Example: (20, 65) → groups (0–20), (21–65), (66+)
    fps : int, optional
        Frames per second used when writing ``save_path``.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The animation object (keep a reference to it to keep it alive).

    Raises
    ------
    ValueError
        If ``age_limits`` does not hold exactly two values, if no male/female
        population columns are found, if the male and female columns do not
        cover the same ages, if a column name has no ``AGE_<n>`` part, or if
        the model data has no rows.
    """
    import warnings

    try:
        limit1, limit2 = age_limits
    except (TypeError, ValueError) as err:
        raise ValueError("age_limits must contain exactly two cutoff ages.") from err

    df = mmodel.basedf
    years = list(df.index)

    # --- ModelFlow pattern filters, ordered numerically by age ---
    # (a plain string sort would put AGE_10 before AGE_2)
    def by_age(column):
        return _column_age(column), column

    male_cols = sorted(mmodel['pop__male*__AGE_*'].names, key=by_age)
    female_cols = sorted(mmodel['pop__female*__AGE_*'].names, key=by_age)
    if not male_cols or not female_cols:
        raise ValueError("No population columns found for male/female in model data.")
    if not years:
        raise ValueError("Model data contains no periods to animate.")

    # --- Extract ages; both sexes must line up bar for bar ---
    ages = [_column_age(c) for c in male_cols]
    if [_column_age(c) for c in female_cols] != ages:
        raise ValueError("Male and female population columns do not cover the same ages.")

    # --- Define age groups dynamically: 0 = young, 1 = middle, 2 = old ---
    group_idx = np.array([0 if age <= limit1 else 1 if age <= limit2 else 2 for age in ages])
    group_masks = [group_idx == g for g in range(3)]
    group_colors = ["#1f77b4", "#2ca02c", "#d62728"]  # blue, green, red
    bar_colors = [group_colors[g] for g in group_idx]
    label_prefixes = [f"0–{limit1}", f"{limit1 + 1}–{limit2}", f"{limit2 + 1}+"]

    # --- Pull all values out once; frames index rows positionally ---
    male_data = df[male_cols].to_numpy()
    female_data = df[female_cols].to_numpy()

    def legend_labels_for(frame):
        """Return the three legend labels with population shares for *frame*."""
        m, f = male_data[frame], female_data[frame]
        total = m.sum() + f.sum()
        labels = []
        for prefix, mask in zip(label_prefixes, group_masks):
            # A zero total would otherwise produce NaN shares.
            share = (m[mask].sum() + f[mask].sum()) / total * 100 if total else 0.0
            labels.append(f"{prefix} ({share:.1f}%)")
        return labels

    # --- Figure setup ---
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlabel('Population')
    ax.set_ylabel('Age')
    ax.set_title('Population Pyramid by Age Group', fontsize=14, pad=25)
    # Male bars extend to negative x; show magnitudes on the axis instead.
    ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{abs(int(x)):,}"))

    legend_handles = [
        plt.Rectangle((0, 0), 1, 1, color=color, label=label)
        for color, label in zip(group_colors, legend_labels_for(0))
    ]
    legend = ax.legend(handles=legend_handles, title="Age Groups", loc="upper right")

    male_bars = ax.barh(ages, -male_data[0], color=bar_colors)
    female_bars = ax.barh(ages, female_data[0], color=bar_colors)
    year_text = ax.text(0.02, 0.95, str(years[0]), transform=ax.transAxes,
                        fontsize=16, fontweight='bold', ha='left', va='top')

    # Fix the x-range over all frames so the bars do not rescale while animating.
    max_pop = df[male_cols + female_cols].max().max()
    ax.set_xlim(-max_pop * 1.1, max_pop * 1.1)

    # --- Add vertical line at x = 0 ---
    ax.axvline(0, color="black", linewidth=1.2)

    # --- Add 'Male' and 'Female' labels ABOVE plot ---
    # Use figure coordinates to ensure fixed position
    fig.text(0.25, 0.94, "Male", ha="center", va="bottom",
             fontsize=14, fontweight='bold', color="steelblue")
    fig.text(0.75, 0.94, "Female", ha="center", va="bottom",
             fontsize=14, fontweight='bold', color="lightcoral")

    # --- Animation update ---
    def update(frame):
        for bar, width in zip(male_bars, -male_data[frame]):
            bar.set_width(width)
        for bar, width in zip(female_bars, female_data[frame]):
            bar.set_width(width)
        year_text.set_text(str(years[frame]))
        for text, label in zip(legend.get_texts(), legend_labels_for(frame)):
            text.set_text(label)
        return []

    anim = FuncAnimation(fig, update, frames=len(years), interval=interval,
                         blit=False, repeat=True)

    # --- Save animation if requested ---
    if save_path:
        target = str(save_path).lower()
        if target.endswith(".gif"):
            writer = "pillow"
        elif target.endswith(".mp4"):
            writer = "ffmpeg"
        else:
            writer = None
        if writer is None:
            # Do not claim success when nothing was written.
            warnings.warn(
                f"Unsupported file extension for save_path={save_path!r}; "
                "expected '.gif' or '.mp4'. Animation was not saved.",
                stacklevel=2,
            )
        else:
            anim.save(save_path, writer=writer, fps=fps)
            print(f"✅ Animation saved to: {save_path}")

    # --- Inline display (Notebook 7 compatible) ---
    if auto_show and IPython.get_ipython() is not None:
        # Close the static figure so only the JS animation is rendered.
        plt.close(fig)
        display(HTML(anim.to_jshtml()))
    return anim
# Example usage:
# anim = animate_population_pyramid(mpopulation, save_path='graph/test1.mp4')